#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import requests

# Name of the task created by create_shard_task_success.
SHARD_TASK_NAME = "test-shard"
# Task name containing non-ASCII and punctuation characters; used by
# get_illegal_char_task_status_failed.
ILLEGAL_CHAR_TASK_NAME = "t-Ë!s`t"
# Names of the two sources referenced by the task definitions in this file.
SOURCE1_NAME = "mysql-01"
SOURCE2_NAME = "mysql-02"


# Base URL of the task API; most helpers below build their request URLs from it.
API_ENDPOINT = "http://127.0.0.1:8361/api/v1/tasks"


def create_task_failed():
    """Post a task definition with an invalid shard mode and expect HTTP 400."""
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": name, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": "t"},
        }
        for name in source_names
    ]
    payload = {
        "task": {
            "name": "test",
            "task_mode": "all",
            # pessimistic_xxd is not a valid shard mode
            "shard_mode": "pessimistic_xxd",
            "meta_schema": "dm-meta",
            "enhance_online_schema_change": True,
            "on_duplicate": "error",
            "target_config": {
                "host": "127.0.0.1",
                "port": 4000,
                "user": "root",
                "password": "",
            },
            "table_migrate_rule": migrate_rules,
            "source_config": {
                "source_conf": [{"source_name": name} for name in source_names],
            },
        }
    }
    response = requests.post(url=API_ENDPOINT, json=payload)
    print("create_task_failed resp=", response.json())
    assert response.status_code == 400

def create_task_with_precheck(task_name, ignore_check, is_success, check_result):
    """Create a pessimistic shard task while ignoring some precheck items.

    The "version" item and ``ignore_check`` are sent as ignore_checking_items.
    When ``is_success`` equals "success" the response must be 201 and
    ``check_result`` must occur in its "check_result" field; otherwise the
    response must be 400 and ``check_result`` must occur in "error_msg".
    """
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": name, "schema": "openapi", "table": "t*"},
            "target": {"schema": "openapi", "table": "t"},
        }
        for name in source_names
    ]
    task = {
        "name": task_name,
        "task_mode": "all",
        "meta_schema": "dm_meta",
        "shard_mode": "pessimistic",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "ignore_checking_items": ["version", ignore_check],
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": name} for name in source_names],
        },
    }

    response = requests.post(url=API_ENDPOINT, json={"task": task})
    print("create_task_with_precheck resp=", response.json())

    # Failure path first: anything other than "success" expects a rejection.
    if is_success != "success":
        assert response.status_code == 400
        assert check_result in response.json()["error_msg"]
        return

    assert response.status_code == 201
    assert check_result in response.json()["check_result"]

def create_noshard_task_success(task_name, tartget_table_name=""):
    """Create a task without a shard_mode and expect HTTP 201.

    ``tartget_table_name`` is used as the target table of both migrate rules.
    """
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": name, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": tartget_table_name},
        }
        for name in source_names
    ]
    payload = {
        "task": {
            "name": task_name,
            "task_mode": "all",
            "meta_schema": "dm-meta",
            "enhance_online_schema_change": True,
            "on_duplicate": "error",
            "target_config": {
                "host": "127.0.0.1",
                "port": 4000,
                "user": "root",
                "password": "",
            },
            "table_migrate_rule": migrate_rules,
            "source_config": {
                "source_conf": [{"source_name": name} for name in source_names],
            },
        }
    }
    response = requests.post(url=API_ENDPOINT, json=payload)
    print("create_noshard_task_success resp=", response.json())
    assert response.status_code == 201

def create_incremental_task_with_gtid_success(
    task_name,
    binlog_name1,
    binlog_pos1,
    binlog_gtid1,
    binlog_name2,
    binlog_pos2,
    binlog_gtid2,
):
    """Create an incremental task and expect HTTP 201.

    When ``binlog_pos1`` is not the empty string, each source entry also
    carries its binlog name, binlog position (converted with ``int``) and
    GTID; otherwise only the source names are sent.
    """
    migrate_rules = [
        {
            "source": {"source_name": name, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": ""},
        }
        for name in (SOURCE1_NAME, SOURCE2_NAME)
    ]

    if binlog_pos1 != "":
        source_conf = [
            {
                "source_name": SOURCE1_NAME,
                "binlog_name": binlog_name1,
                "binlog_pos": int(binlog_pos1),
                "binlog_gtid": binlog_gtid1,
            },
            {
                "source_name": SOURCE2_NAME,
                "binlog_name": binlog_name2,
                "binlog_pos": int(binlog_pos2),
                "binlog_gtid": binlog_gtid2,
            },
        ]
    else:
        source_conf = [
            {"source_name": SOURCE1_NAME},
            {"source_name": SOURCE2_NAME},
        ]

    task = {
        "name": task_name,
        "task_mode": "incremental",
        "meta_schema": "dm_meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {"source_conf": source_conf},
    }

    response = requests.post(url=API_ENDPOINT, json={"task": task})
    print("create_incremental_task_with_gtid_success resp=", response.json())
    assert response.status_code == 201

def create_shard_task_success():
    """Create the pessimistic shard task SHARD_TASK_NAME and expect HTTP 201.

    Each source's migrate rule references its own binlog filter rule
    ("rule-1" for source 1, "rule-2" for source 2).
    """
    filter_by_source = ((SOURCE1_NAME, "rule-1"), (SOURCE2_NAME, "rule-2"))
    migrate_rules = [
        {
            "source": {"source_name": source, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": "t"},
            "binlog_filter_rule": [filter_name],
        }
        for source, filter_name in filter_by_source
    ]
    source_config = {
        "full_migrate_conf": {
            "export_threads": 4,
            "import_threads": 16,
            "data_dir": "./exported_data",
            "consistency": "auto",
        },
        "incr_migrate_conf": {"repl_threads": 16, "repl_batch": 100},
        "source_conf": [{"source_name": source} for source, _ in filter_by_source],
    }
    binlog_filters = {
        "rule-1": {"ignore_event": ["delete"]},
        "rule-2": {"ignore_sql": ["alter table openapi.t add column aaa int;"]},
    }
    payload = {
        "task": {
            "name": SHARD_TASK_NAME,
            "task_mode": "all",
            "shard_mode": "pessimistic",
            "meta_schema": "dm-meta",
            "enhance_online_schema_change": True,
            "on_duplicate": "error",
            "target_config": {
                "host": "127.0.0.1",
                "port": 4000,
                "user": "root",
                "password": "",
            },
            "table_migrate_rule": migrate_rules,
            "source_config": source_config,
            "binlog_filter_rule": binlog_filters,
        }
    }
    response = requests.post(url=API_ENDPOINT, json=payload)
    print("create_shard_task_success resp=", response.json())
    assert response.status_code == 201


def start_task_success(task_name, source_name):
    """Start ``task_name`` and expect HTTP 200.

    A non-empty ``source_name`` restricts the request to that source and
    sets remove_meta; an empty one sends an empty request body.
    """
    start_url = "/".join([API_ENDPOINT, task_name, "start"])
    if source_name != "":
        body = {"source_name_list": [source_name], "remove_meta": True}
    else:
        body = {}
    response = requests.post(url=start_url, json=body)
    # Only dump the response body when the call did not succeed.
    if response.status_code != 200:
        print("start_task_failed resp=", response.json())
    assert response.status_code == 200

def start_task_failed(task_name, source_name, check_result):
    """Start ``task_name`` and expect HTTP 400 with ``check_result`` in error_msg.

    A non-empty ``source_name`` restricts the request to that source and
    sets remove_meta; an empty one sends an empty request body.
    """
    start_url = "/".join([API_ENDPOINT, task_name, "start"])
    if source_name != "":
        body = {"source_name_list": [source_name], "remove_meta": True}
    else:
        body = {}
    response = requests.post(url=start_url, json=body)
    unexpectedly_ok = response.status_code == 200
    if unexpectedly_ok:
        print("start_task_failed resp should not be 200")
    assert response.status_code == 400
    print("start_task_failed resp=", response.json())
    error_msg = response.json()["error_msg"]
    assert check_result in error_msg

def start_task_with_condition(task_name, start_time, duration, is_success, check_result):
    """Start ``task_name`` with a start time and safe-mode duration.

    When ``is_success`` equals "success" the response must be 200; otherwise
    it must be 400 and ``check_result`` must occur in its "error_msg".
    """
    start_url = "/".join([API_ENDPOINT, task_name, "start"])
    body = {"start_time": start_time, "safe_mode_time_duration": duration}
    response = requests.post(url=start_url, json=body)

    # Failure expectation handled first so the success path reads flat.
    if is_success != "success":
        if response.status_code == 200:
            print("start_task_with_condition_failed resp should not be 200")
        assert response.status_code == 400
        print("start_task_with_condition resp=", response.json())
        assert check_result in response.json()["error_msg"]
        return

    if response.status_code != 200:
        print("start_task_with_condition_failed resp=", response.json())
    assert response.status_code == 200
    print("start_task_with_condition success")

def stop_task_with_condition(task_name, source_name, timeout_duration):
    """Stop ``task_name`` with a timeout duration and expect HTTP 200.

    A non-empty ``source_name`` additionally restricts the request to that
    source.
    """
    stop_url = "/".join([API_ENDPOINT, task_name, "stop"])
    if source_name != "":
        body = {"source_name_list": [source_name], "timeout_duration": timeout_duration}
    else:
        body = {"timeout_duration": timeout_duration}
    response = requests.post(url=stop_url, json=body)
    # Only dump the response body when the call did not succeed.
    if response.status_code != 200:
        print("stop_task_failed resp=", response.json())
    assert response.status_code == 200

def stop_task_success(task_name, source_name):
    """Stop ``task_name`` and expect HTTP 200.

    A non-empty ``source_name`` restricts the request to that source; an
    empty one sends an empty request body.
    """
    stop_url = "/".join([API_ENDPOINT, task_name, "stop"])
    if source_name != "":
        body = {"source_name_list": [source_name]}
    else:
        body = {}
    response = requests.post(url=stop_url, json=body)
    # Only dump the response body when the call did not succeed.
    if response.status_code != 200:
        print("stop_task_failed resp=", response.json())
    assert response.status_code == 200

def delete_task_success(task_name):
    """Delete ``task_name`` and expect HTTP 204."""
    task_url = "/".join([API_ENDPOINT, task_name])
    response = requests.delete(url=task_url)
    assert response.status_code == 204
    print("delete_task_success")


def delete_task_failed(task_name):
    """Delete ``task_name`` and expect the request to be rejected with HTTP 400."""
    task_url = "/".join([API_ENDPOINT, task_name])
    response = requests.delete(url=task_url)
    print("delete_task_failed resp=", response.json())
    assert response.status_code == 400


def delete_task_with_force_success(task_name):
    """Delete ``task_name`` with the ``force=true`` query flag and expect HTTP 204."""
    forced_url = "/".join([API_ENDPOINT, task_name]) + "?force=true"
    response = requests.delete(url=forced_url)
    assert response.status_code == 204
    print("delete_task_success")


def get_task_status_failed(task_name):
    """Query the status of ``task_name`` and expect HTTP 400."""
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    response = requests.get(url=status_url)
    print("get_task_status_failed resp=", response.json())
    assert response.status_code == 400


def get_illegal_char_task_status_failed():
    """Query the status of ILLEGAL_CHAR_TASK_NAME and expect HTTP 400.

    The task name contains illegal characters, but the API server can handle
    it; the 400 is returned because the task is not started. The error
    message must echo the task name.
    """
    status_url = "/".join([API_ENDPOINT, ILLEGAL_CHAR_TASK_NAME, "status"])
    response = requests.get(url=status_url)
    print("get_illegal_char_task_status_failed resp=", response.json())
    assert response.status_code == 400

    expected_name = ILLEGAL_CHAR_TASK_NAME
    if sys.version_info.major == 2:
        # Python 2 string literals are bytes; decode before comparing.
        expected_name = ILLEGAL_CHAR_TASK_NAME.decode("utf-8")
    assert expected_name in response.json()["error_msg"]


def get_task_status_success(task_name, total):
    """Query the status of ``task_name``; expect HTTP 200 and ``total`` entries.

    ``total`` is converted with ``int`` before being compared against the
    response's "total" field.
    """
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    response = requests.get(url=status_url)
    body = response.json()
    assert response.status_code == 200
    print("get_task_status_success resp=", body)
    reported_total = body["total"]
    assert reported_total == int(total)


def get_task_status_success_but_worker_meet_error(task_name, total):
    """Query the status of ``task_name`` and expect errors on every entry.

    The response must be HTTP 200 with ``int(total)`` entries, and each entry
    in "data" must carry the task name and a non-None "error_msg".
    """
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    response = requests.get(url=status_url)
    body = response.json()
    assert response.status_code == 200
    print("get_task_status_success_but_worker_meet_error resp=", body)
    assert body["total"] == int(total)
    for entry in body["data"]:
        assert entry["name"] == task_name
        assert entry["error_msg"] is not None


def get_task_list(task_count):
    """List all tasks; expect HTTP 200 and ``int(task_count)`` as the total."""
    response = requests.get(url=API_ENDPOINT)
    body = response.json()
    assert response.status_code == 200
    print("get_task_list resp=", body)
    reported_total = body["total"]
    assert reported_total == int(task_count)


def get_task_list_with_status(task_count, task_name, status_count):
    """List tasks with status and check the entry for ``task_name``.

    Expects HTTP 200, ``int(task_count)`` as the total, and at least one
    listed task named ``task_name``; every such task must have
    ``int(status_count)`` items in its "status_list".
    """
    response = requests.get(url=API_ENDPOINT + "?with_status=true")
    body = response.json()
    assert response.status_code == 200
    print("get_task_list_with_status resp=", body)

    assert body["total"] == int(task_count)
    matched = False
    for entry in body["data"]:
        if entry["name"] != task_name:
            continue
        matched = True
        assert len(entry["status_list"]) == int(status_count)
    assert matched


def operate_schema_and_table_success(task_name, source_name, schema_name, table_name):
    """Read the schema/table tree of a task source, then overwrite one table.

    Checks in order: the schema list contains ``schema_name``; the table list
    of that schema contains ``table_name``; the table detail reports the
    matching names and mentions ``table_name`` in its create SQL. Finally the
    table is overwritten with a fixed CREATE TABLE statement and the table
    list is expected to hold exactly one entry.
    """
    schemas_url = "/".join(
        [API_ENDPOINT, task_name, "sources", source_name, "schemas"]
    )
    schemas_response = requests.get(url=schemas_url)
    assert schemas_response.status_code == 200
    schemas = schemas_response.json()
    print("get_task_schema_success schema resp=", schemas)
    assert len(schemas) > 0
    assert schema_name in schemas

    tables_url = "/".join([schemas_url, schema_name])
    tables_response = requests.get(url=tables_url)
    assert tables_response.status_code == 200
    tables = tables_response.json()
    print("get_task_schema_success table resp=", tables)
    assert table_name in tables

    table_detail_url = "/".join([tables_url, table_name])
    detail_response = requests.get(url=table_detail_url)
    assert detail_response.status_code == 200
    detail = detail_response.json()
    print("get_task_schema_success create table resp=", detail)
    assert detail["table_name"] == table_name
    assert detail["schema_name"] == schema_name
    assert table_name in detail["schema_create_sql"]

    # overwrite table
    overwrite_response = requests.put(
        url=table_detail_url,
        json={
            "sql_content": "CREATE TABLE openapi.t1(i TINYINT, j INT UNIQUE KEY);",
            "flush": True,
            "sync": True,
        },
    )
    assert overwrite_response.status_code == 200
    tables_after = requests.get(url=tables_url).json()
    print("get_task_schema_success table resp=", tables_after)
    assert len(tables_after) == 1


def create_task_template_success(task_name, target_table_name=""):
    """Create a pessimistic shard task template and expect HTTP 201.

    ``target_table_name`` is used as the target table of both migrate rules.
    The task definition is posted as the request body itself.
    """
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": name, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": target_table_name},
        }
        for name in source_names
    ]
    template = {
        "name": task_name,
        "task_mode": "all",
        "shard_mode": "pessimistic",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": name} for name in source_names],
        },
    }
    templates_url = "/".join([API_ENDPOINT, "templates"])
    response = requests.post(url=templates_url, json=template)
    print("create_task_template_success resp=", response.json())
    assert response.status_code == 201


def create_task_template_failed():
    """Post a task template with an invalid shard mode and expect HTTP 400."""
    source_names = (SOURCE1_NAME, SOURCE2_NAME)
    migrate_rules = [
        {
            "source": {"source_name": name, "schema": "openapi", "table": "*"},
            "target": {"schema": "openapi", "table": "t"},
        }
        for name in source_names
    ]
    template = {
        "name": "test",
        "task_mode": "all",
        # pessimistic_xxd is not a valid shard mode
        "shard_mode": "pessimistic_xxd",
        "meta_schema": "dm-meta",
        "enhance_online_schema_change": True,
        "on_duplicate": "error",
        "target_config": {
            "host": "127.0.0.1",
            "port": 4000,
            "user": "root",
            "password": "",
        },
        "table_migrate_rule": migrate_rules,
        "source_config": {
            "source_conf": [{"source_name": name} for name in source_names],
        },
    }
    templates_url = "/".join([API_ENDPOINT, "templates"])
    response = requests.post(url=templates_url, json=template)
    print("create_task_template_failed resp=", response.json())
    assert response.status_code == 400


def list_task_template(count):
    """List task templates; expect HTTP 200 and ``int(count)`` as the total."""
    templates_url = "/".join([API_ENDPOINT, "templates"])
    response = requests.get(url=templates_url)
    body = response.json()
    assert response.status_code == 200
    print("list_task_template resp=", body)
    reported_total = body["total"]
    assert reported_total == int(count)


def import_task_template(success_count, failed_count):
    """Trigger a template import and check the result counts.

    Expects HTTP 202 with ``int(success_count)`` entries in
    "success_task_list" and ``int(failed_count)`` in "failed_task_list".
    """
    import_url = "/".join([API_ENDPOINT, "templates", "import"])
    response = requests.post(url=import_url)
    body = response.json()
    print("import_task_template resp=", body)
    assert response.status_code == 202
    succeeded = body["success_task_list"]
    assert len(succeeded) == int(success_count)
    failed = body["failed_task_list"]
    assert len(failed) == int(failed_count)


def get_task_template(name):
    """Fetch the template ``name``; expect HTTP 200 and a matching "name" field."""
    template_url = "/".join([API_ENDPOINT, "templates", name])
    response = requests.get(url=template_url)
    body = response.json()
    assert response.status_code == 200
    print("get_task_template resp=", body)
    returned_name = body["name"]
    assert returned_name == name


def update_task_template_success(name, task_mode):
    """Change the task_mode of template ``name`` and verify it was stored.

    The template is fetched, its "task_mode" replaced, and the result sent
    back with PUT (expecting HTTP 200); a second fetch must then report the
    new mode.
    """
    template_url = "/".join([API_ENDPOINT, "templates", name])

    # get task template first
    template = requests.get(url=template_url).json()
    template["task_mode"] = task_mode
    update_response = requests.put(url=template_url, json=template)
    print("update_task_template_success resp=", update_response.json())
    assert update_response.status_code == 200

    # update task template success
    stored_template = requests.get(url=template_url).json()
    assert stored_template["task_mode"] == task_mode


def delete_task_template(name):
    """Delete the template ``name`` and expect HTTP 204."""
    template_url = "/".join([API_ENDPOINT, "templates", name])
    response = requests.delete(url=template_url)
    assert response.status_code == 204
    print("delete_task_template")


def check_noshard_task_dump_status_success(task_name, total):
    """Check the dumped byte count reported for ``task_name``.

    Expects HTTP 200 and that the first status entry's
    dump_status.finished_bytes equals ``int(total)``.
    """
    status_url = "/".join([API_ENDPOINT, task_name, "status"])
    response = requests.get(url=status_url)
    body = response.json()
    assert response.status_code == 200
    print("check_dump_status_success resp=", body)
    first_entry = body["data"][0]
    finished_bytes = first_entry["dump_status"]["finished_bytes"]
    assert finished_bytes == int(total)


def do_complex_operations(task_name):
    """Run a combined source/task update scenario on ``task_name``.

    Steps, each verified with assertions:
      1. updating source 1 and updating the task are both rejected (400)
         before the task is stopped;
      2. after stopping the task every subtask is "Paused" and the task can
         be updated (repl_threads set to 1);
      3. after starting again every subtask is "Running";
      4. disabling source 1 pauses only the subtask on source 1, after which
         the source can be updated;
      5. enabling source 1 brings every subtask back to "Running";
      6. the migrate targets listed for source 1 match the task's migrate
         rules for that source.
    """
    source1_url = "http://127.0.0.1:8361/api/v1/sources/" + SOURCE1_NAME
    # Built from API_ENDPOINT like every other task URL in this file.
    task_url = API_ENDPOINT + "/" + task_name
    enable_source_url = source1_url + "/enable"
    disable_source_url = source1_url + "/disable"

    stop_url = task_url + "/stop"
    start_url = task_url + "/start"
    status_url = task_url + "/status"
    migrate_targets_url = task_url + "/sources/" + SOURCE1_NAME + "/migrate_targets"

    # get source
    source = requests.get(source1_url).json()
    # update source failed
    # NOTE: update_source_req keeps a reference to `source`, so the in-place
    # edits made to `source` further down are part of the later PUT as well.
    update_source_req = {"source": source}
    resp = requests.put(source1_url, json=update_source_req)
    assert resp.status_code == 400

    # get task
    task = requests.get(task_url).json()

    # update task failed
    update_task_req = {"task": task}
    resp = requests.put(task_url, json=update_task_req)
    assert resp.status_code == 400

    # stop task; check the response like every other step so that a failed
    # stop is reported directly instead of surfacing as a stage mismatch.
    resp = requests.post(stop_url)
    if resp.status_code != 200:
        print("stop task failed", resp.json())
    assert resp.status_code == 200
    status = requests.get(status_url).json()
    for s in status["data"]:
        assert s["stage"] == "Paused"

    # update task success
    task["source_config"]["incr_migrate_conf"]["repl_threads"] = 1
    update_task_req = {"task": task}
    resp = requests.put(task_url, json=update_task_req)
    if resp.status_code != 200:
        print("update task failed", resp.json())
    assert resp.status_code == 200
    task_after_updated = requests.get(task_url).json()
    assert task_after_updated["source_config"]["incr_migrate_conf"]["repl_threads"] == 1

    # start again
    resp = requests.post(start_url)
    if resp.status_code != 200:
        print("start task failed", resp.json())
    assert resp.status_code == 200
    status = requests.get(status_url).json()
    for s in status["data"]:
        assert s["stage"] == "Running"

    # disable source1, subtask on source is paused
    resp = requests.post(disable_source_url)
    if resp.status_code != 200:
        print("disable source failed", resp.json())
    assert resp.status_code == 200
    status = requests.get(status_url).json()
    for s in status["data"]:
        if s["source_name"] == SOURCE1_NAME:
            assert s["stage"] == "Paused"
        else:
            assert s["stage"] == "Running"

    # update source1 success
    source["enable_gtid"] = True
    source["password"] = "123456"
    resp = requests.put(source1_url, json=update_source_req)
    if resp.status_code != 200:
        print("update source failed", resp.json())
    assert resp.status_code == 200

    # enable source task will start again
    resp = requests.post(enable_source_url)
    if resp.status_code != 200:
        print("enable source failed", resp.json())
    assert resp.status_code == 200
    status = requests.get(status_url).json()
    for s in status["data"]:
        assert s["stage"] == "Running"

    # list migrate targets
    source1_migrate_rules = [
        rule
        for rule in task["table_migrate_rule"]
        if rule["source"]["source_name"] == SOURCE1_NAME
    ]

    resp = requests.get(migrate_targets_url)
    if resp.status_code != 200:
        print("list migrate targets failed", resp.json())
    assert resp.status_code == 200
    data = resp.json()
    assert data["total"] == len(source1_migrate_rules)
    assert (
        data["data"][0]["source_schema"] == source1_migrate_rules[0]["source"]["schema"]
    )
    assert (
        data["data"][0]["target_schema"] == source1_migrate_rules[0]["target"]["schema"]
    )


# Command-line entry point: the first argument selects a helper by name and
# the remaining arguments are passed to it positionally, as strings.
if __name__ == "__main__":
    FUNC_MAP = {
        "create_task_failed": create_task_failed,
        "create_noshard_task_success": create_noshard_task_success,
        "create_shard_task_success": create_shard_task_success,
        "create_incremental_task_with_gtid_success": create_incremental_task_with_gtid_success,
        "delete_task_failed": delete_task_failed,
        "delete_task_success": delete_task_success,
        "delete_task_with_force_success": delete_task_with_force_success,
        "start_task_success": start_task_success,
        "start_task_failed": start_task_failed,
        "start_task_with_condition": start_task_with_condition,
        "stop_task_with_condition": stop_task_with_condition,
        "stop_task_success": stop_task_success,
        "get_task_list": get_task_list,
        "get_task_list_with_status": get_task_list_with_status,
        "get_task_status_failed": get_task_status_failed,
        "get_illegal_char_task_status_failed": get_illegal_char_task_status_failed,
        "get_task_status_success": get_task_status_success,
        "get_task_status_success_but_worker_meet_error": get_task_status_success_but_worker_meet_error,
        "operate_schema_and_table_success": operate_schema_and_table_success,
        "create_task_template_success": create_task_template_success,
        "create_task_template_failed": create_task_template_failed,
        "list_task_template": list_task_template,
        "import_task_template": import_task_template,
        "get_task_template": get_task_template,
        "update_task_template_success": update_task_template_success,
        "delete_task_template": delete_task_template,
        "check_noshard_task_dump_status_success": check_noshard_task_dump_status_success,
        "do_complex_operations": do_complex_operations,
        "create_task_with_precheck": create_task_with_precheck,
    }

    # Validate before indexing: a missing or unknown function name gets a
    # usage message and a non-zero exit instead of an IndexError/KeyError
    # traceback.
    if len(sys.argv) < 2 or sys.argv[1] not in FUNC_MAP:
        sys.stderr.write(
            "usage: {} <function_name> [args...]\navailable functions: {}\n".format(
                sys.argv[0], ", ".join(sorted(FUNC_MAP))
            )
        )
        sys.exit(1)

    # An empty argument list unpacks to a plain call, so no separate branch
    # is needed for helpers that take no arguments.
    FUNC_MAP[sys.argv[1]](*sys.argv[2:])
